/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package io.trino.parquet.reader;

import com.google.common.collect.ImmutableList;
import com.google.common.io.Resources;
import io.airlift.slice.Slice;
import io.airlift.slice.Slices;
import io.trino.parquet.DataPage;
import io.trino.parquet.DataPageV2;
import io.trino.parquet.ParquetDataSource;
import io.trino.parquet.ParquetDataSourceId;
import io.trino.parquet.ParquetReaderOptions;
import io.trino.parquet.PrimitiveField;
import io.trino.parquet.metadata.ParquetMetadata;
import io.trino.plugin.base.type.DecodedTimestamp;
import io.trino.spi.block.Block;
import io.trino.spi.block.Fixed12Block;
import io.trino.spi.connector.SourcePage;
import io.trino.spi.type.SqlTimestamp;
import io.trino.spi.type.TimestampType;
import io.trino.spi.type.Timestamps;
import io.trino.spi.type.Type;
import org.apache.parquet.bytes.HeapByteBufferAllocator;
import org.apache.parquet.column.ColumnDescriptor;
import org.apache.parquet.column.values.ValuesWriter;
import org.apache.parquet.column.values.plain.FixedLenByteArrayPlainValuesWriter;
import org.apache.parquet.schema.Types;
import org.joda.time.DateTimeZone;
import org.junit.jupiter.api.Test;

import java.io.File;
import java.io.IOException;
import java.net.URISyntaxException;
import java.time.LocalDateTime;
import java.util.List;
import java.util.Optional;
import java.util.OptionalLong;

import static io.airlift.slice.Slices.EMPTY_SLICE;
import static io.trino.memory.context.AggregatedMemoryContext.newSimpleAggregatedMemoryContext;
import static io.trino.parquet.ParquetEncoding.PLAIN;
import static io.trino.parquet.ParquetTestUtils.createParquetReader;
import static io.trino.parquet.reader.TestingColumnReader.encodeInt96Timestamp;
import static io.trino.spi.type.TimestampType.TIMESTAMP_MICROS;
import static io.trino.spi.type.TimestampType.TIMESTAMP_MILLIS;
import static io.trino.spi.type.TimestampType.TIMESTAMP_NANOS;
import static io.trino.spi.type.TimestampType.TIMESTAMP_PICOS;
import static io.trino.spi.type.Timestamps.MICROSECONDS_PER_SECOND;
import static io.trino.spi.type.Timestamps.NANOSECONDS_PER_MICROSECOND;
import static io.trino.spi.type.Timestamps.PICOSECONDS_PER_NANOSECOND;
import static java.lang.Math.floorDiv;
import static java.lang.Math.floorMod;
import static java.time.ZoneOffset.UTC;
import static java.time.temporal.ChronoField.NANO_OF_SECOND;
import static org.apache.parquet.format.CompressionCodec.UNCOMPRESSED;
import static org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName.INT96;
import static org.assertj.core.api.Assertions.assertThat;

/**
 * Tests reading Parquet INT96 timestamp columns into timestamp types of
 * millisecond, microsecond, nanosecond and picosecond precision.
 */
public class TestInt96Timestamp
{
    // Sample values: far past, pre-epoch, around the epoch boundary, present day and far future,
    // with sub-second parts at nanosecond, microsecond and millisecond granularity.
    private static final LocalDateTime[] TIMESTAMPS = new LocalDateTime[] {
            LocalDateTime.of(-5000, 1, 1, 1, 1, 1),
            LocalDateTime.of(-1, 4, 2, 4, 2, 4, 2),
            LocalDateTime.of(1, 1, 1, 0, 0),
            LocalDateTime.of(1410, 7, 15, 14, 30, 12),
            LocalDateTime.of(1920, 8, 15, 23, 59, 59, 10020030),
            LocalDateTime.of(1969, 12, 31, 23, 59, 59, 999),
            LocalDateTime.of(1969, 12, 31, 23, 59, 59, 999999),
            LocalDateTime.of(1969, 12, 31, 23, 59, 59, 999999999),
            LocalDateTime.of(1970, 1, 1, 0, 0),
            LocalDateTime.of(1970, 1, 1, 0, 0, 0, 1),
            LocalDateTime.of(1970, 1, 1, 0, 0, 0, 1000),
            LocalDateTime.of(1970, 1, 1, 0, 0, 0, 1000000),
            LocalDateTime.of(2022, 2, 3, 12, 8, 51, 123456789),
            LocalDateTime.of(2022, 2, 3, 12, 8, 51, 1),
            LocalDateTime.of(2022, 2, 3, 12, 8, 51, 999999999),
            LocalDateTime.of(123456, 1, 2, 3, 4, 5, 678901234)};

    /**
     * Round-trips every value of {@link #TIMESTAMPS} through an INT96 data page for each
     * supported timestamp precision.
     */
    @Test
    public void testVariousTimestamps()
            throws IOException
    {
        testVariousTimestamps(TIMESTAMP_MILLIS);
        testVariousTimestamps(TIMESTAMP_MICROS);
        testVariousTimestamps(TIMESTAMP_NANOS);
        testVariousTimestamps(TIMESTAMP_PICOS);
    }

    /**
     * Reads a prepared Parquet file whose INT96 values carry an out-of-range time-of-day
     * component and checks the resulting nanosecond-precision timestamps.
     */
    @Test
    public void testNanosOutsideDayRange()
            throws IOException, URISyntaxException
    {
        List<String> columnNames = ImmutableList.of("timestamp");
        List<Type> types = ImmutableList.of(TIMESTAMP_NANOS);

        // int96_timestamps_nanos_outside_day_range.parquet file is prepared with timeOfDayNanos values which are
        // outside the [0, NANOSECONDS_PER_DAY] range to simulate data generated by AWS wrangler.
        // https://github.com/aws/aws-sdk-pandas/issues/592#issuecomment-920716270
        // All other known parquet writers don't violate the [0, NANOSECONDS_PER_DAY] range for timeOfDayNanos
        File file = new File(Resources.getResource("int96_timestamps_nanos_outside_day_range.parquet").toURI());

        ImmutableList.Builder<LocalDateTime> builder = ImmutableList.builder();
        // Release the file handle and the reader even when reading fails part-way through
        try (ParquetDataSource dataSource = new FileParquetDataSource(file, ParquetReaderOptions.defaultOptions())) {
            ParquetMetadata parquetMetadata = MetadataReader.readFooter(dataSource, Optional.empty());
            try (ParquetReader reader = createParquetReader(dataSource, parquetMetadata, newSimpleAggregatedMemoryContext(), types, columnNames)) {
                // nextPage() returning null marks the end of the data
                for (SourcePage page = reader.nextPage(); page != null; page = reader.nextPage()) {
                    Fixed12Block block = (Fixed12Block) page.getBlock(0);
                    for (int position = 0; position < block.getPositionCount(); position++) {
                        builder.add(toLocalDateTime(block, position));
                    }
                }
            }
        }

        assertThat(builder.build()).containsExactlyInAnyOrder(
                LocalDateTime.of(-5001, 12, 31, 4, 22, 57, 193656253),
                LocalDateTime.of(1970, 1, 1, 13, 43, 58, 721344111),
                LocalDateTime.of(1969, 12, 30, 22, 14, 51, 243235321),
                LocalDateTime.of(-1, 4, 2, 22, 35, 10, 668330477),
                LocalDateTime.of(0, 12, 30, 0, 7, 53, 939664765),
                LocalDateTime.of(1410, 7, 15, 2, 18, 26, 329074140),
                LocalDateTime.of(1920, 8, 17, 12, 56, 2, 190285077),
                LocalDateTime.of(1969, 12, 31, 12, 11, 30, 879147442),
                LocalDateTime.of(1969, 12, 30, 6, 26, 40, 679553451),
                LocalDateTime.of(1970, 1, 2, 3, 38, 38, 483312394),
                LocalDateTime.of(1970, 1, 2, 12, 0, 27, 672539248),
                LocalDateTime.of(2022, 2, 4, 7, 55, 30, 455814445),
                LocalDateTime.of(123456, 1, 2, 0, 56, 30, 494898191));
    }

    /**
     * Encodes {@link #TIMESTAMPS} as a single PLAIN-encoded INT96 data page, reads it back as
     * {@code type} and verifies epoch seconds and nanos-of-second of every decoded value.
     *
     * @param type the timestamp type the column reader is asked to produce
     * @throws IOException if building or reading the page fails
     */
    private void testVariousTimestamps(TimestampType type)
            throws IOException
    {
        int valueCount = TIMESTAMPS.length;

        PrimitiveField field = new PrimitiveField(type, true, new ColumnDescriptor(new String[] {"dummy"}, Types.required(INT96).named("dummy"), 0, 0), 0);
        ValuesWriter writer = new FixedLenByteArrayPlainValuesWriter(12, 1024, 1024, HeapByteBufferAllocator.getInstance());

        for (LocalDateTime timestamp : TIMESTAMPS) {
            writer.writeBytes(encodeInt96Timestamp(timestamp.toEpochSecond(UTC), timestamp.get(NANO_OF_SECOND)));
        }

        Slice slice = Slices.wrappedBuffer(writer.getBytes().toByteArray());
        DataPage dataPage = new DataPageV2(
                valueCount,
                0,
                valueCount,
                EMPTY_SLICE,
                EMPTY_SLICE,
                PLAIN,
                slice,
                slice.length(),
                OptionalLong.empty(),
                null,
                false,
                0);
        // Read and assert
        ColumnReaderFactory columnReaderFactory = new ColumnReaderFactory(DateTimeZone.UTC, ParquetReaderOptions.defaultOptions());
        ColumnReader reader = columnReaderFactory.create(field, newSimpleAggregatedMemoryContext());
        PageReader pageReader = new PageReader(new ParquetDataSourceId("test"), UNCOMPRESSED, List.of(dataPage).iterator(), false, false, Optional.empty(), -1, -1);
        reader.setPageReader(pageReader, Optional.empty());
        reader.prepareNextRead(valueCount);
        Block block = reader.readPrimitive().getBlock();

        // Fail with a clear message, rather than an index error below, if the reader returned too few values
        assertThat(block.getPositionCount())
                .as("position count of block read as %s", type)
                .isEqualTo(valueCount);

        // Number of trailing nanosecond digits the target type cannot represent (0 for precision >= 9)
        int precisionToNanos = Math.max(0, 9 - type.getPrecision());
        for (int i = 0; i < valueCount; i++) {
            LocalDateTime timestamp = TIMESTAMPS[i];
            long expectedEpochSeconds = timestamp.toEpochSecond(UTC);
            int expectedNanos = (int) Timestamps.round(timestamp.get(NANO_OF_SECOND), precisionToNanos);
            // Rounding may produce a full second, which carries into the seconds component
            if (expectedNanos == 1_000_000_000) {
                expectedEpochSeconds++;
                expectedNanos = 0;
            }

            DecodedTimestamp actual = toDecodedTimestamp(type, block, i);
            assertThat(actual.epochSeconds())
                    .as("epoch seconds of %s read as %s", timestamp, type)
                    .isEqualTo(expectedEpochSeconds);
            assertThat(actual.nanosOfSecond())
                    .as("nanos of second of %s read as %s", timestamp, type)
                    .isEqualTo(expectedNanos);
        }
    }

    /**
     * Converts the value at {@code position} into epoch seconds and nanos-of-second.
     *
     * @param timestampType the type the block was read as; selects the short or long representation
     * @param block the block holding the decoded timestamps; must be a {@link Fixed12Block} for long types
     * @param position the position within {@code block}
     * @return the decoded timestamp, with any sub-nanosecond part truncated
     */
    private static DecodedTimestamp toDecodedTimestamp(TimestampType timestampType, Block block, int position)
    {
        if (timestampType.isShort()) {
            // Read through the type under test; the value is interpreted as microseconds since the epoch
            long epochMicros = timestampType.getLong(block, position);
            return new DecodedTimestamp(
                    floorDiv(epochMicros, MICROSECONDS_PER_SECOND),
                    floorMod(epochMicros, MICROSECONDS_PER_SECOND) * NANOSECONDS_PER_MICROSECOND);
        }
        Fixed12Block fixed12Block = (Fixed12Block) block;
        // First component: microseconds since the epoch; second component: picoseconds within that microsecond
        long epochMicros = fixed12Block.getFixed12First(position);
        int picosOfMicro = fixed12Block.getFixed12Second(position);
        return new DecodedTimestamp(
                floorDiv(epochMicros, MICROSECONDS_PER_SECOND),
                floorMod(epochMicros, MICROSECONDS_PER_SECOND) * NANOSECONDS_PER_MICROSECOND + picosOfMicro / PICOSECONDS_PER_NANOSECOND);
    }

    /**
     * Converts the nanosecond-precision timestamp stored at {@code position} into a {@link LocalDateTime}.
     *
     * @param block the block holding epoch microseconds and picoseconds-of-microsecond pairs
     * @param position the position within {@code block}
     * @return the value as a {@link LocalDateTime}
     */
    private static LocalDateTime toLocalDateTime(Fixed12Block block, int position)
    {
        long epochMicros = block.getFixed12First(position);
        int picosOfMicro = block.getFixed12Second(position);

        return SqlTimestamp.newInstance(9, epochMicros, picosOfMicro)
                .toLocalDateTime();
    }
}
